import logging
import os
import pickle
import re
import tempfile
import time
import uuid
from contextlib import contextmanager
from pathlib import Path
from typing import List, Optional, Tuple

import pyarrow.fs
import pytest

import ray
import ray.train
import ray.train.collective
from ray._common.test_utils import simulate_s3_bucket
from ray.air._internal.uri_utils import URI
from ray.train import (
    Checkpoint,
    CheckpointConfig,
    FailureConfig,
    RunConfig,
    ScalingConfig,
)
from ray.train.v2._internal.constants import HEALTH_CHECK_INTERVAL_S_ENV_VAR
from ray.train.v2._internal.execution.storage import _download_from_fs_path
from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer


class TestConstants:
    """Shared sizing constants and metric key used by the persistence test."""

    # Passed to train_fn as "num_iterations"; every iteration reports once,
    # so this is also the number of checkpoints expected without retention.
    NUM_ITERATIONS = 6  # == num_checkpoints == num_artifacts
    # NOTE(review): not referenced in the visible part of this file.
    NUM_TRIALS = 2
    # Number of workers requested through ScalingConfig(num_workers=...).
    NUM_WORKERS = 3

    # Metric name reported by train_fn and used as checkpoint_score_attribute.
    SCORE_KEY = "score"


@contextmanager
def mock_s3_bucket_uri():
    """Start a simulated S3 server, create a bucket in it, and yield its URI.

    While the context is active, the "werkzeug" logger is raised to WARNING to
    silence the mock server's HTTP request logging. The logger's previous
    level is restored on exit, including when the body of the ``with`` block
    raises.

    Yields:
        URI: The URI of the bucket created on the simulated S3 server.
    """
    port = 5002
    region = "us-west-2"
    with simulate_s3_bucket(port=port, region=region) as s3_uri:
        import boto3

        s3 = boto3.client(
            "s3", region_name=region, endpoint_url=f"http://localhost:{port}"
        )
        # Bucket name will be autogenerated/unique per test
        bucket_name = URI(s3_uri).name
        s3.create_bucket(
            Bucket=bucket_name,
            CreateBucketConfiguration={"LocationConstraint": region},
        )
        # Disable server HTTP request logging
        werkzeug_logger = logging.getLogger("werkzeug")
        previous_level = werkzeug_logger.level
        werkzeug_logger.setLevel(logging.WARNING)
        try:
            yield URI(s3_uri)
        finally:
            # An exception raised inside the caller's `with` body is re-raised
            # at the `yield`; without `finally` the level would stay at WARNING.
            werkzeug_logger.setLevel(previous_level)


@contextmanager
def dummy_context_manager(*args, **kwargs):
    """No-op context manager that ignores its arguments.

    Used as a stand-in wherever a real context manager is optional.

    Yields:
        str: The fixed placeholder string "dummy value".
    """
    placeholder = "dummy value"
    yield placeholder


@pytest.fixture(autouse=True, scope="module")
def ray_start_4_cpus():
    """Module-scoped autouse fixture: start Ray with 4 CPUs, shut down after."""
    cpu_count = 4
    ray.init(num_cpus=cpu_count)
    yield
    ray.shutdown()


def _create_mock_custom_fs(custom_fs_root_dir: Path) -> pyarrow.fs.FileSystem:
    """Build a pyarrow filesystem backed by a local directory.

    The directory is created if it does not exist. An fsspec
    ``DirFileSystem`` over a ``LocalFileSystem`` is constructed with
    ``custom_fs_root_dir`` as its path, and exposed to pyarrow through
    ``FSSpecHandler`` / ``PyFileSystem``.

    Args:
        custom_fs_root_dir: Local directory passed as the ``DirFileSystem`` path.

    Returns:
        The wrapping ``pyarrow.fs.PyFileSystem``.
    """
    from fsspec.implementations.dirfs import DirFileSystem
    from fsspec.implementations.local import LocalFileSystem

    custom_fs_root_dir.mkdir(parents=True, exist_ok=True)

    rooted_fs = DirFileSystem(path=str(custom_fs_root_dir), fs=LocalFileSystem())
    fsspec_handler = pyarrow.fs.FSSpecHandler(rooted_fs)
    return pyarrow.fs.PyFileSystem(fsspec_handler)


@contextmanager
def _resolve_storage_type(
    storage_path_type: str, tmp_path: Path
) -> Tuple[str, Optional[pyarrow.fs.FileSystem]]:
    """Context manager yielding ``(storage_path, storage_filesystem)``.

    The yielded pair depends on ``storage_path_type``:
    - "nfs": a "fake_nfs" directory path under ``tmp_path``; no filesystem.
    - "cloud": the URI string from ``mock_s3_bucket_uri``; no filesystem.
    - "custom_fs": the path "mock_bucket" plus a mock custom filesystem
      rooted at ``tmp_path / "custom_fs"``.
    - anything else: ``(None, None)``.

    The mock S3 server is only entered for the "cloud" type; every other type
    runs under a no-op context manager.
    """
    is_cloud = storage_path_type == "cloud"
    enter_storage = mock_s3_bucket_uri if is_cloud else dummy_context_manager

    with enter_storage() as cloud_storage_path:
        resolved_path = None
        resolved_fs = None

        if is_cloud:
            resolved_path = str(cloud_storage_path)
        elif storage_path_type == "nfs":
            resolved_path = str(tmp_path / "fake_nfs")
        elif storage_path_type == "custom_fs":
            resolved_path = "mock_bucket"
            resolved_fs = _create_mock_custom_fs(tmp_path / "custom_fs")

        yield resolved_path, resolved_fs


def _get_local_inspect_dir(
    root_local_path: Path,
    storage_path: str,
    storage_filesystem: Optional[pyarrow.fs.FileSystem],
    storage_local_path: Optional[Path] = None,
) -> Tuple[Path, str]:
    """Downloads the storage path -> local dir for inspecting contents.

    Args:
        root_local_path: Local directory under which the "inspect" download
            directory is placed.
        storage_path: Storage path or URI to download. When falsy, nothing is
            downloaded and ``storage_local_path`` is inspected directly.
        storage_filesystem: Filesystem to use with ``storage_path``. When
            falsy, the filesystem is resolved from ``storage_path`` as a URI.
        storage_local_path: Local directory to inspect when no
            ``storage_path`` is given.

    Returns:
        Tuple: (local_inspect_dir, storage_fs_path), where storage_fs_path
            is the path to the storage path on the filesystem (e.g., prefix stripped).
            This is used to check the correctness of paths returned from `Result`'s,
            since URIs are hard to do comparisons with.

    Raises:
        ValueError: If neither ``storage_path`` nor ``storage_local_path``
            is provided.
    """
    if not storage_path:
        # Nothing to download: the contents already live on local disk.
        if storage_local_path is None:
            raise ValueError(
                "Either `storage_path` or `storage_local_path` must be provided."
            )
        return storage_local_path, str(storage_local_path)

    if storage_filesystem:
        fs, storage_fs_path = storage_filesystem, storage_path
    else:
        fs, storage_fs_path = pyarrow.fs.FileSystem.from_uri(storage_path)

    local_inspect_dir = root_local_path / "inspect"
    _download_from_fs_path(
        fs=fs, fs_path=storage_fs_path, local_path=str(local_inspect_dir)
    )
    return local_inspect_dir, storage_fs_path


def _get_checkpoint_epoch(checkpoint_dir_name: str) -> int:
    """Gets the checkpoint index from the checkpoint directory name."""
    pattern = r"checkpoint_epoch=(\d+)"
    match = re.search(pattern, checkpoint_dir_name)
    assert match
    return int(match.group(1))


def _create_checkpoint_shard_filename(rank_str: str) -> str:
    return f"checkpoint_shard-rank={rank_str}.pkl"


def _get_checkpoint_shard_rank(checkpoint_shard_filename: str) -> int:
    """Extract the worker rank encoded in a checkpoint shard filename.

    Fails an assertion if the filename does not match the shard pattern.
    """
    # Reuse the filename template, with a capture group standing in for the
    # rank, so this parser stays in sync with the writer.
    # NOTE(review): the "." before "pkl" is not escaped, so it matches any
    # character there.
    rank_regex = re.compile(_create_checkpoint_shard_filename(r"(\d+)"))
    found = rank_regex.search(checkpoint_shard_filename)
    assert found
    return int(found[1])


def train_fn(config):
    """Training loop run by each worker; reports metrics and checkpoints.

    If a checkpoint is available, the loop resumes from the iteration after
    the one stored in it. On every iteration each worker sleeps, reports
    metrics (with a checkpoint unless its rank is excluded), waits on a
    barrier, and optionally raises to simulate a failure.

    Keys read from ``config`` (all optional):
        custom_restore_fn: Callable taking the checkpoint and returning the
            saved state dict; by default "checkpoint.pkl" is unpickled.
        num_to_keep: Cap used when asserting how many reported checkpoints
            are visible (default: no cap).
        num_iterations: Exclusive upper bound of the iteration range (default 5).
        time_per_iter: Seconds to sleep per iteration (default 0.25).
        no_checkpoint_ranks: World ranks that report metrics without a checkpoint.
        custom_save_fn: Context manager factory called with the temporary
            checkpoint directory; ``report`` runs inside it.
        fail_iters: Iterations after which a RuntimeError is raised.

    Raises:
        RuntimeError: After reporting an iteration listed in ``fail_iters``.
    """
    # Check that the working dir for each worker is the shared trial dir.
    # assert Path.cwd() == Path(train_session.storage.trial_working_directory).resolve()

    start = 0

    # Resume from the latest checkpoint, if there is one.
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        custom_restore_fn = config.get("custom_restore_fn")
        if custom_restore_fn:
            state = custom_restore_fn(checkpoint)
        else:
            with checkpoint.as_directory() as checkpoint_dir:
                with open(os.path.join(checkpoint_dir, "checkpoint.pkl"), "rb") as f:
                    state = pickle.load(f)
        print("Loaded back state from checkpoint:", state)
        # The checkpoint stores the last completed iteration.
        start = state["iter"] + 1

    # One checkpoint is reported per completed iteration, so `start` of them
    # are expected to be visible, capped by the retention limit.
    got = len(ray.train.get_all_reported_checkpoints())
    expected = min(start, config.get("num_to_keep", float("inf")))
    assert got == expected, f"Expected {expected} checkpoints, got {got}"

    for i in range(start, config.get("num_iterations", 5)):
        time.sleep(config.get("time_per_iter", 0.25))

        metrics = {"iter": i, TestConstants.SCORE_KEY: i}

        rank = ray.train.get_context().get_world_rank()
        if rank in config.get("no_checkpoint_ranks", []):
            # This rank reports metrics only, under the same checkpoint dir
            # name that the other ranks use.
            ray.train.report(
                metrics, checkpoint=None, checkpoint_dir_name=f"checkpoint_epoch={i}"
            )
        else:
            with tempfile.TemporaryDirectory() as temp_dir:
                # File with the same name on every checkpointing rank.
                with open(os.path.join(temp_dir, "checkpoint.pkl"), "wb") as f:
                    pickle.dump({"iter": i}, f)

                # Rank-specific shard file.
                checkpoint_file_name = _create_checkpoint_shard_filename(str(rank))
                with open(os.path.join(temp_dir, checkpoint_file_name), "wb") as f:
                    pickle.dump({"iter": i}, f)

                with config.get("custom_save_fn", dummy_context_manager)(temp_dir):
                    ray.train.report(
                        metrics,
                        checkpoint=Checkpoint.from_directory(temp_dir),
                        checkpoint_dir_name=f"checkpoint_epoch={i}",
                    )
                # `train.report` should not have deleted this!
                assert os.path.exists(temp_dir)

        # TODO: This barrier before raising is a workaround to deflake the test.
        # In this test setup, rank 0 is the fast-reporting worker
        # that does not upload a checkpoint.
        # If rank 0 raises an error immediately after getting past `report`,
        # the next iteration of the control loop will handle the failure
        # and the checkpoints from all other ranks will not be processed.
        # This results in an earlier checkpoint getting used during restoration,
        # which will cause the test assertions to fail.
        # This should be fixed by forcing a queue flush on all workers before
        # executing the failure decisions.
        ray.train.collective.barrier()

        if i in config.get("fail_iters", []):
            # Iterations 0..i have been reported, so i + 1 checkpoints are
            # expected, capped by the retention limit.
            got = len(ray.train.get_all_reported_checkpoints())
            expected = min(i + 1, config.get("num_to_keep", float("inf")))
            assert got == expected, f"Expected {expected} checkpoints, got {got}"
            raise RuntimeError(f"Failing on iter={i}!!")


def _assert_storage_contents(
    local_inspect_dir: Path,
    exp_name: str,
    checkpoint_config: CheckpointConfig,
    no_checkpoint_ranks: Optional[List[int]] = None,
    constants: type = TestConstants,
):
    """Assert that the downloaded storage directory has the expected layout.

    Checks that:
    - ``local_inspect_dir`` contains exactly one entry, named ``exp_name``.
    - The experiment dir holds the expected ``checkpoint_epoch=*`` dirs:
      the last ``num_to_keep`` epochs if set, otherwise all of them.
    - Every checkpoint dir has one ``checkpoint.pkl`` and one shard file per
      worker rank, except for ranks in ``no_checkpoint_ranks``.

    Args:
        local_inspect_dir: Local copy of the storage path to inspect.
        exp_name: Expected name of the single experiment directory.
        checkpoint_config: Config whose ``num_to_keep`` bounds the number of
            checkpoint directories expected.
        no_checkpoint_ranks: Worker ranks that did not report a checkpoint.
        constants: Object providing ``NUM_ITERATIONS`` and ``NUM_WORKERS``.
    """
    no_checkpoint_ranks = no_checkpoint_ranks or []

    # The storage path should contain nothing but the experiment directory.
    storage_path_ls = list(local_inspect_dir.glob("*"))
    assert len(storage_path_ls) == 1  # Only expect 1 experiment dir
    exp_dir = storage_path_ls[0]
    assert exp_dir.name == exp_name

    # Check checkpoint contents
    # If set, expect num_to_keep. Otherwise, expect to see all of them.
    expected_num_checkpoints = checkpoint_config.num_to_keep or constants.NUM_ITERATIONS

    # List the checkpoint directories once and reuse the listing below.
    checkpoint_dirs = list(exp_dir.glob("checkpoint_epoch=*"))
    assert len(checkpoint_dirs) == expected_num_checkpoints

    checkpoint_epochs = sorted(
        _get_checkpoint_epoch(checkpoint_dir.name) for checkpoint_dir in checkpoint_dirs
    )
    # Ex: If num_to_keep=2 out of 6 total checkpoints,
    # expect checkpoint_epoch=4 and checkpoint_epoch=5.
    assert checkpoint_epochs == list(
        range(
            constants.NUM_ITERATIONS - expected_num_checkpoints,
            constants.NUM_ITERATIONS,
        )
    )

    # 1 checkpoint shard per worker.
    # Unless the worker did not report a checkpoint (no_checkpoint_ranks).
    expected_shard_ranks = {
        rank for rank in range(constants.NUM_WORKERS) if rank not in no_checkpoint_ranks
    }

    for checkpoint_dir in checkpoint_dirs:
        # 1 shared checkpoint.pkl file, written by the trainable / all workers.
        assert len(list(checkpoint_dir.glob("checkpoint.pkl"))) == 1

        # The shard check is unconditional: the previous `if test_trainer:`
        # guard resolved to the module-level test function, which is always
        # truthy, so it never actually gated anything.
        found_shard_ranks = {
            _get_checkpoint_shard_rank(checkpoint_shard.name)
            for checkpoint_shard in checkpoint_dir.glob("checkpoint_shard-*.pkl")
        }
        assert found_shard_ranks == expected_shard_ranks


@pytest.mark.parametrize("storage_path_type", ["nfs", "cloud", "custom_fs"])
@pytest.mark.parametrize(
    "checkpoint_config",
    [
        CheckpointConfig(),
        CheckpointConfig(
            num_to_keep=1,
            checkpoint_score_attribute=TestConstants.SCORE_KEY,
            checkpoint_score_order="max",
        ),
    ],
)
def test_trainer(
    monkeypatch, tmp_path, storage_path_type, checkpoint_config: CheckpointConfig
):
    """End-to-end persistence test across several `storage_path_type` options:
    - "nfs" --> save locally to a fake NFS path
    - "cloud" --> save to a mock S3 bucket
    - "custom_fs" --> save to a custom pyarrow filesystem
        - The custom fs is a local filesystem rooted at a directory prefix.

    The run is configured to fail on iterations 2 and 4 with up to 2 retries,
    and a second trainer with the same run config is fitted afterwards.
    Rank 0 reports metrics without a checkpoint.

    Expected layout at the storage path:

    {RunConfig.storage_path}/{RunConfig.name}
    └── checkpoint_epoch={epoch}              <- Checkpoint directories with custom name
        ├── checkpoint.pkl                    <- Shared checkpoint file
        ├── checkpoint_shard-rank=1.pkl       <- Worker checkpoint shards
        └── checkpoint_shard-rank=2.pkl
    └── ...
    """
    health_check_interval_s = 0.1
    monkeypatch.setenv(HEALTH_CHECK_INTERVAL_S_ENV_VAR, str(health_check_interval_s))
    # Report interval is twice the health check interval. The factor is
    # arbitrary; it is meant to mimic a somewhat realistic scenario.
    time_between_reports = health_check_interval_s * 2

    exp_name = f"trainer_persistence_test-{uuid.uuid4().hex}"
    no_checkpoint_ranks = [0]
    # No retention limit configured means train_fn should see every checkpoint.
    num_to_keep = checkpoint_config.num_to_keep or float("inf")

    with _resolve_storage_type(storage_path_type, tmp_path) as resolved_storage:
        storage_path, storage_filesystem = resolved_storage

        run_config = RunConfig(
            storage_path=storage_path,
            storage_filesystem=storage_filesystem,
            name=exp_name,
            checkpoint_config=checkpoint_config,
            failure_config=FailureConfig(max_failures=2),
        )

        def build_trainer() -> DataParallelTrainer:
            # A fresh config dict is built per trainer; both share run_config.
            return DataParallelTrainer(
                train_fn,
                train_loop_config={
                    "num_iterations": TestConstants.NUM_ITERATIONS,
                    "fail_iters": [2, 4],
                    # Test that global rank 0 is not required to checkpoint.
                    "no_checkpoint_ranks": no_checkpoint_ranks,
                    "time_per_iter": time_between_reports,
                    "num_to_keep": num_to_keep,
                },
                scaling_config=ScalingConfig(num_workers=TestConstants.NUM_WORKERS),
                run_config=run_config,
            )

        trainer = build_trainer()
        print("\nStarting initial run.\n")
        result = trainer.fit()

        print("\nStarting manually restored run.\n")
        restored_trainer = build_trainer()
        # Only the result of this second run is inspected below.
        result = restored_trainer.fit()

        # Download while the storage context (e.g. the mock S3 server) is
        # still alive.
        local_inspect_dir, storage_fs_path = _get_local_inspect_dir(
            root_local_path=tmp_path,
            storage_path=run_config.storage_path,
            storage_filesystem=storage_filesystem,
        )

    # First, inspect that the result object returns the correct paths.
    print(result)
    run_path = result.path
    assert run_path.startswith(storage_fs_path)
    for best_checkpoint, _ in result.best_checkpoints:
        assert best_checkpoint.path.startswith(run_path)

    # Then, inspect what was actually persisted.
    _assert_storage_contents(
        local_inspect_dir,
        exp_name,
        checkpoint_config,
        no_checkpoint_ranks=no_checkpoint_ranks,
    )


if __name__ == "__main__":
    import sys

    # Run this file under pytest: verbose, stopping at the first failure.
    exit_code = pytest.main(["-v", "-x", __file__])
    sys.exit(exit_code)
